{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo - Difference between eval mode and training mode (CIFAR10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "import torchvision.utils\n",
    "from torchvision import models\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import torchattacks\n",
    "\n",
    "from models import Target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# CIFAR-10 test split (train=False); ToTensor() is the only transform applied.\n",
    "batch_size = 128\n",
    "\n",
    "cifar10_test = dsets.CIFAR10(\n",
    "    root='./data',\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transforms.ToTensor(),\n",
    ")\n",
    "\n",
    "# shuffle=False keeps the batch order fixed, so both attack runs below\n",
    "# iterate over the same batches in the same order.\n",
    "test_loader = DataLoader(cifar10_test, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is available and fall back to the CPU otherwise, so the\n",
    "# notebook also runs on machines without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = Target().to(device)\n",
    "\n",
    "# map_location remaps the tensors stored in the checkpoint onto `device`; without\n",
    "# it, a checkpoint whose tensors were saved on a GPU cannot be loaded on a\n",
    "# CPU-only machine.\n",
    "# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.\n",
    "state_dict = torch.load(\"./checkpoints/target.pth\", map_location=device)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "# Put the model in eval mode before any attack is built on top of it.\n",
    "model = model.eval()"
   ]
  },
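  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Eval mode vs. Training mode\n",
    "\n",
    "The same FGSM attack is run over the test set twice: first as created, then after `set_training_mode(True)` has been called on it."
   ]
  },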
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Save Progress: 100.00 % / Accuracy: 23.30 % / L2: 0.43309\r"
     ]
    }
   ],
   "source": [
    "# FGSM with eps = 2/255, built on the model loaded above.\n",
    "atk = torchattacks.FGSM(model, eps=2/255)\n",
    "\n",
    "# Run the attack over the whole test loader. save_path=None is passed, so the\n",
    "# call is used here for its verbose progress/accuracy report\n",
    "# (presumably nothing is written to disk - confirm against the torchattacks docs).\n",
    "atk.save(\n",
    "    data_loader=test_loader,\n",
    "    save_path=None,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Save Progress: 100.00 % / Accuracy: 35.72 % / L2: 0.43311\r"
     ]
    }
   ],
   "source": [
    "# Reuse the attack object from the previous cell, switched to training mode.\n",
    "atk.set_training_mode(True)\n",
    "\n",
    "# Second pass over the same loader with the same arguments, so the reported\n",
    "# numbers can be compared directly with the previous cell's output.\n",
    "atk.save(\n",
    "    data_loader=test_loader,\n",
    "    save_path=None,\n",
    "    verbose=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
